import torch
import torchvision
import typing
import numpy as np


class CIFAR10Generation(torch.utils.data.Dataset):
    """Index/class-name view over the CIFAR-10 training split.

    Items are ``(raw_index, class_name)`` pairs: the position of the sample
    in the underlying torchvision CIFAR-10 training set and the
    human-readable name of its (possibly overridden) label. No image data is
    returned.
    """

    def __init__(
            self, 
            root: str, 
            lira_indices: list=None,
            noisy_targets: list=None,
        ):
        """Load the CIFAR-10 training split and optionally restrict/relabel it.

        Args:
            root: Directory handed to ``torchvision.datasets.CIFAR10``; the
                dataset is requested with ``download=True``.
            lira_indices: Raw training-set indices exposed by this dataset.
                When ``None``, every training sample is exposed in order.
            noisy_targets: Replacement labels. Entry ``i`` overwrites the
                label of raw training sample ``i``. When ``None``, the
                original labels are kept.
        """
        # Label id -> name lookup; position in the list is the label id.
        self.class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.cifar10 = torchvision.datasets.CIFAR10(root=root, train=True, download=True)

        if noisy_targets is not None:
            # Overwrite labels in place, position by position.
            for position, label in enumerate(noisy_targets):
                self.cifar10.targets[position] = label

        self.indices = list(range(len(self.cifar10))) if lira_indices is None else lira_indices

    def __len__(self):
        """Return the number of exposed samples."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(raw_index, class_name)`` for the ``idx``-th exposed sample."""
        raw_index = self.indices[idx]
        label = self.cifar10.targets[raw_index]
        return raw_index, self.class_names[label]


def in_out_split_noisy(
        clean_train_ys: list, 
        seed: int, 
        num_shadow: int, 
        num_canaries: int, 
        fixed_halves: typing.Optional[bool] = None,
    ) -> typing.Tuple[list, list, list]:
    # Everything from here on depends on the seed
    # All indices are relative to the full raw training set
    # All index arrays (except label noise order) are stored sorted in increasing order
    rng = np.random.default_rng(seed=seed)

    num_raw_train_samples = len(clean_train_ys)
    num_classes = 10
    clean_train_ys = torch.from_numpy(np.array(clean_train_ys))

    # 1) IN-OUT splits
    rng_splits_target, rng_splits_shadow, rng = rng.spawn(3)
    # Currently, we are not using any target models. However, keep rng for compatibility if we need them later.
    del rng_splits_target
    # This ensures that every sample is IN in exactly half of all shadow models if all samples were varied.
    # Calculate splits for all training samples, s.t. the membership is independent of the number of canaries
    # If the number of shadow models changes, then everything changes either way
    assert num_shadow % 2 == 0
    shadow_in_indices_t = np.argsort(
        rng_splits_shadow.uniform(size=(num_shadow, num_raw_train_samples)), axis=0
    )[: num_shadow // 2].T
    raw_shadow_in_indices = []
    for shadow_idx in range(num_shadow):
        raw_shadow_in_indices.append(
            torch.from_numpy(np.argwhere(np.any(shadow_in_indices_t == shadow_idx, axis=1)).flatten())
        )
    rng_splits_half, rng_splits_shadow = rng_splits_shadow.spawn(2)  # used later for fixed splits for validation
    del rng_splits_shadow

    # 2) Canary indices
    rng_canaries, rng = rng.spawn(2)
    canary_order = rng_canaries.permutation(num_raw_train_samples)
    del rng_canaries

    # Calculate proper IN indices depending on setting
    shadow_in_indices = []
    # Normal case; all non-canary samples are always IN
    canary_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
    canary_mask[canary_order[: num_canaries]] = True

    if fixed_halves is None:
        for shadow_idx in range(num_shadow):
            current_in_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
            current_in_mask[raw_shadow_in_indices[shadow_idx]] = True
            current_in_mask[~canary_mask] = True
            shadow_in_indices.append(torch.argwhere(current_in_mask).flatten())
    else:
        # Special case to validate the setting
        # Always only use half of CIFAR10, but either vary by shadow model, or use a fixed half of non-canaries
        if not fixed_halves:
            # Raw shadow indices are already half of the full training data
            shadow_in_indices = raw_shadow_in_indices
        else:
            # Need to calculate a fixed half of non-canaries
            canary_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
            canary_mask[canary_order[: num_canaries]] = True
            fixed_membership_full = torch.from_numpy(rng_splits_half.random(num_raw_train_samples) < 0.5)
            for shadow_idx in range(num_shadow):
                current_in_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
                # IN: IN canaries and fixed non-canaries
                current_in_mask[raw_shadow_in_indices[shadow_idx]] = True
                current_in_mask[~canary_mask] = False
                current_in_mask[(~canary_mask) & fixed_membership_full] = True
                shadow_in_indices.append(torch.argwhere(current_in_mask).flatten())
    del rng_splits_half

    # 3) Canary transforms
    rng_canary_transforms, rng = rng.spawn(2)
    # 3.1) Noisy labels for all samples
    rng_noise, rng_canary_transforms = rng_canary_transforms.spawn(2)
    label_changes = torch.from_numpy(rng_noise.integers(num_classes - 1, size=num_raw_train_samples))
    noisy_labels = torch.where(label_changes < clean_train_ys, label_changes, label_changes + 1)
    del rng_noise

    del rng

    noisy_targets = clean_train_ys.clone()
    canary_indices = canary_order[: num_canaries]
    noisy_targets[canary_indices] = noisy_labels[canary_indices]

    noisy_targets = list(noisy_targets.cpu().numpy())
    shadow_in_indices = [_.cpu().numpy() for _ in shadow_in_indices]

    return noisy_targets, shadow_in_indices, canary_indices
